/*
 * Javlov - a Java toolkit for reinforcement learning with multi-agent support.
 * 
 * Copyright (c) 2009 Matthijs Snel
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package net.javlov.world.grid;

import java.awt.Point;
import java.awt.Shape;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import net.javlov.Action;
import net.javlov.Agent;
import net.javlov.Environment;
import net.javlov.RewardFunction;
import net.javlov.State;
import net.javlov.world.AgentBody;
import net.javlov.world.Body;
import net.javlov.world.CollisionEvent;
import net.javlov.world.CollisionListener;
import net.javlov.world.World;
import net.javlov.world.phys2d.Phys2DBody;
import net.javlov.world.phys2d.Phys2DWorld;
import net.phys2d.math.ROVector2f;

/**
 * Grid world implementation that assumes all moving bodies fit within one grid cell.
 * <p>
 * Every body is kept in {@link #bodies} and is also registered with each {@link GridCell} its
 * figure intersects, so spatial queries only need to inspect a few cells. Bodies placed by
 * {@link #randomlyPositionAll()} may span several cells, but the movement code
 * ({@link #translateBody(Body, Direction, int)}) looks up a single cell from the body's location
 * and therefore only works for bodies that fit within one cell.
 * @author matthijs
 *
 */
public class GridWorld implements World.Discrete, IGridWorld {

	/**
	 * Maps agents to their bodies. Done this way because an agent is not allowed to have access
	 * to its own environment/body.
	 */
	protected Map<Agent, AgentBody> agentBodyMap;
	
	/**
	 * All the bodies in the world.
	 */
	protected List<Body> bodies;
	
	/**
	 * The grid.
	 */
	protected Grid grid;
	
	/**
	 * Set through {@link RewardBroker#endEpisode()} to mark the current episode as finished;
	 * copied into the terminal flag of every observation constructed afterwards.
	 */
	protected boolean episodeEnd;
	
	/** Listeners notified of every collision detected while moving bodies. */
	private List<CollisionListener> listeners;
	
	/** Reward function consulted before and after each action; must be set before acting. */
	protected GridRewardFunction reward;

	/**
	 * Most recent observation. Single-agent hack for speed: it is shared by all agents, so it is
	 * only meaningful when the world contains a single agent.
	 */
	protected State lastState;

	/**
	 * Creates an empty world on a grid of {@code width} x {@code height} cells.
	 *
	 * @param width number of cells horizontally
	 * @param height number of cells vertically
	 * @param cellWidth width of a single cell, in world units
	 * @param cellHeight height of a single cell, in world units
	 */
	public GridWorld(int width, int height, double cellWidth, double cellHeight) {
		agentBodyMap = new HashMap<Agent, AgentBody>();
		bodies = new ArrayList<Body>();
		grid = new Grid(width, height, cellWidth, cellHeight);
		listeners = new ArrayList<CollisionListener>();
	}
	
	/**
	 * Adds the body to the world and, if that succeeded, associates it with the agent.
	 *
	 * @return true if the body was added
	 */
	@Override
	public boolean add(Agent a, AgentBody b) {
		if ( addBody(b) ) {
			agentBodyMap.put(a, b);
			return true;
		}
		return false;
	}

	/**
	 * Adds the body to the body list and registers it with every cell its figure intersects.
	 *
	 * @return true if the body list changed
	 */
	@Override
	public boolean addBody(Body b) {
		if ( bodies.add(b) ) {
			addToCells(b);
			return true;
		}
		return false;
	}

	/**
	 * Registers a listener that is notified of every collision detected while moving bodies.
	 */
	public void addCollisionListener(CollisionListener listener) {
		listeners.add(listener);
	}
	
	/**
	 * Returns the body associated with the agent, or null if the agent is not in this world.
	 */
	@Override
	public Body getAgentBody(Agent a) {
		return agentBodyMap.get(a);
	}

	/**
	 * Returns the height of the world in world units (number of rows times cell height).
	 */
	@Override
	public double getHeight() {
		return grid.getHeight()*grid.getCellHeight();
	}

	/**
	 * {@inheritDoc}
	 * <p>
	 * Only bodies registered in the cells overlapped by {@code s} are tested, and the test is
	 * against each body's bounding box rather than its exact figure.
	 */
	@Override
	public List<Body> getIntersectingObjects(Shape s) {
		List<Body> objects = new ArrayList<Body>();
		for ( Body b : getOccupiers( getCandidateCells(s) ) ) {
			if ( s.intersects(b.getFigure().getBounds2D()) )
				objects.add(b);
		}
		return objects;
	}
	
	/**
	 * Returns the live list of all bodies in the world (not a copy).
	 */
	@Override
	public List<Body> getObjects() {
		return bodies;
	}

	/**
	 * Not implemented for grid worlds: always returns 0.
	 */
	@Override
	public double getTimeStep() {
		// TODO Auto-generated method stub
		return 0;
	}

	/**
	 * Returns the width of the world in world units (number of columns times cell width).
	 */
	@Override
	public double getWidth() {
		return grid.getWidth()*grid.getCellWidth();
	}

	/**
	 * {@inheritDoc}
	 * <p>
	 * Uses the same cell lookup as {@link #getIntersectingObjects(Shape)}, so shapes larger
	 * than a single cell are handled correctly as well.
	 */
	@Override
	public boolean intersectsObject(Shape s) {
		for ( Body b : getOccupiers( getCandidateCells(s) ) )
			if ( s.intersects(b.getFigure().getBounds2D()) )
				return true;
		return false;
	}
	
	/**
	 * Returns the cells a shape may overlap: the cheap neighbourhood lookup when the shape's
	 * bounding box is no larger than one cell, the exhaustive scan otherwise.
	 */
	private List<GridCell> getCandidateCells(Shape s) {
		Rectangle2D bounds = s.getBounds2D();
		if ( bounds.getWidth() > grid.getCellWidth() || bounds.getHeight() > grid.getCellHeight() )
			return getIntersectingCellsLarge(s);
		return getIntersectingCells(s);
	}
	
	/**
	 * Returns the cells intersected by a shape that is no larger than one cell: the cell
	 * containing the centre of the shape's bounding box plus, if the shape is not fully contained
	 * in that cell, those neighbours in the centre's quadrant that intersect the bounding box.
	 */
	protected List<GridCell> getIntersectingCells(Shape s) {
		//IMPORTANT: the following assumes the shape is SMALLER
		//than the width & height of a cell!
		Rectangle2D bounds = s.getBounds2D();
		double cx = bounds.getCenterX(),
				cy = bounds.getCenterY();
		List<GridCell> intersectingCells = new ArrayList<GridCell>(5);
		
		GridCell cell = grid.getCell(cx, cy);
		intersectingCells.add(cell);
		//if the shape is not completely contained within the current cell, check for other
		//intersecting cells
		if ( !cell.contains(bounds) ) {
			GridCell[] neighbours = cell.getQuadrantNeightbours(cx, cy);
			for ( int j = 0; j < neighbours.length; j++ )
				if ( neighbours[j].intersects(bounds) )
					intersectingCells.add(neighbours[j]);
		}
		return intersectingCells;
	}
	
	/**
	 * Returns all cells intersected by a shape of any size. The shape's integer bounding box,
	 * clamped to the world, is sampled in steps of one cell (so each column/row is visited once),
	 * and only cells the shape actually intersects are kept.
	 */
	protected List<GridCell> getIntersectingCellsLarge(Shape s) {
		Rectangle2D bounds = s.getBounds();
		List<GridCell> boundsCells = new ArrayList<GridCell>();
		
		// the extra cell width/height guarantees the last partially covered column/row is sampled
		double 	minx = Math.max(bounds.getMinX(), 0),
				maxx = Math.min(bounds.getMaxX()+grid.getCellWidth(), grid.getWidth()*grid.getCellWidth()),
				miny = Math.max(bounds.getMinY(), 0),
				maxy = Math.min(bounds.getMaxY()+grid.getCellHeight(), grid.getHeight()*grid.getCellHeight());
				
		for ( double x = minx; x < maxx; x += grid.getCellWidth() )
			for ( double y = miny; y < maxy; y += grid.getCellHeight() )
				boundsCells.add( grid.getCell(x, y) );
		
		List<GridCell> intersectingCells = new ArrayList<GridCell>(boundsCells.size());
		for ( GridCell cell : boundsCells )
			if ( s.intersects(cell) )
				intersectingCells.add(cell);
		
		return intersectingCells;
	}
	
	/**
	 * Returns the union (without duplicates) of the occupiers of the given cells.
	 */
	protected Set<Body> getOccupiers(Collection<? extends GridCell> cells) {
		Set<Body> occupiers = new HashSet<Body>( (int) Math.ceil(cells.size() / 0.75) );
		for ( GridCell c : cells )
			occupiers.addAll(c.getOccupiers());
		return occupiers;
	}
	
	/**
	 * Removes the agent's body from the world and forgets the agent.
	 *
	 * @return true if the body was removed and the agent was registered
	 */
	@Override
	public boolean remove(Agent a) {
		Body b = agentBodyMap.get(a);
		if ( removeBody(b) )
			return agentBodyMap.remove(a) != null;
		return false;
	}

	/**
	 * Removes the body from the body list and from every cell its figure intersects.
	 *
	 * @return true if the body was in the world
	 */
	@Override
	public boolean removeBody(Body b) {
		if ( bodies.remove(b) ) {
			removeFromCells(b);
			return true;
		}
		return false;
	}

	/**
	 * Executes an action for an agent and returns the resulting reward.
	 * <p>
	 * The reward function is first given the observation the agent acted on ({@link #lastState},
	 * which may be null if {@link #getObservation(Agent)} was not called since the last
	 * init/reset). The action is then executed, a fresh observation is stored in
	 * {@link #lastState}, and the reward for that observation is returned.
	 */
	@Override
	public double executeAction(Action act, Agent a) {
		reward.preAction(act, a, lastState);
		act.execute(a);
		lastState = constructObservation(a);
		return reward.calculateReward(lastState);
	}

	/**
	 * Builds a fresh observation from the agent's body and marks it terminal if the episode
	 * has ended.
	 */
	protected State constructObservation(Agent a) {
		State s = agentBodyMap.get(a).getObservation(a);
		s.setTerminal(episodeEnd);
		return s;
	}
	
	/**
	 * Returns the state by calling {@link AgentBody#getObservation(Agent)} on the agent's body.
	 * The observation is cached in {@link #lastState} and reused until the next action, init or
	 * reset.
	 */
	@Override
	public State getObservation(Agent a) {
		if ( lastState == null ) {
			lastState = constructObservation(a);
		}
		return lastState;
	}

	/**
	 * Returns the state dim as indicated by {@link AgentBody#getObservationDim()} of an
	 * arbitrary registered agent body. Throws {@link java.util.NoSuchElementException} if no
	 * agent has been added.
	 */
	@Override
	public int getObservationDim() {
		Iterator<AgentBody> it = agentBodyMap.values().iterator();
		return it.next().getObservationDim();
	}

	/**
	 * Starts a new episode from scratch: every body (agents and scenery) is removed from the
	 * grid and randomly repositioned, after which every agent body is initialised.
	 */
	@Override
	public void init() {
		episodeEnd = false;
		lastState = null;
		//remove everything from grid
		for ( Body b : bodies )
			removeFromCells(b);
		//put everything back
		randomlyPositionAll();
		
		for ( Environment env : agentBodyMap.values() )
			env.init();
	}

	/**
	 * Starts a new episode keeping the scenery in place: only agent bodies are removed from
	 * the grid and moved to random free positions, after which every agent body is reset.
	 */
	@Override
	public void reset() {
		episodeEnd = false;
		lastState = null;
		//remove and reallocate agents
		for ( Body b : agentBodyMap.values() )
			removeFromCells(b);
		
		for ( Body b : agentBodyMap.values() )
			setRandomPosition(b);
		
		for ( Environment env : agentBodyMap.values() )
			env.reset();
	}

	/**
	 * Rotates the body without checking for collisions (since body is assumed to fit in
	 * one cell). Always returns true.
	 */
	@Override
	public boolean rotateBody(Body b, double angle) {
		b.setBearing(b.getBearing()+angle);
		return true;
	}

	/**
	 * Doesn't do anything. After initialisation of the world, the grid cannot be changed.
	 */
	@Override
	public void setHeight(int height) {}

	/**
	 * Doesn't do anything. After initialisation of the world, the grid cannot be changed.
	 */
	@Override
	public void setWidth(int width) {}

	/**
	 * Moves a body by {@code dx} cells horizontally and {@code dy} cells vertically, one cell at
	 * a time, stopping early at the border or at a blocking body.
	 * <p>
	 * Only straight moves (one of dx, dy is zero) and exact diagonals (|dx| == |dy|) are
	 * allowed; either component may be negative.
	 *
	 * @return true if the body moved at least one cell; false for a zero move or if blocked
	 *         immediately
	 * @throws IllegalArgumentException if the move is neither straight nor an exact diagonal
	 */
	@Override
	public boolean translateBody(Body b, int dx, int dy) {
		int absx = Math.abs(dx),
			absy = Math.abs(dy);
		if ( absx > 0 && absy > 0 && absx != absy )
			throw new IllegalArgumentException("GridWorld: can only move in straight "
					+ "line, or in diagonal such that dx=dy.");
		// nothing to do; also avoids asking Direction for a (0,0) direction
		if ( absx == 0 && absy == 0 )
			return false;
		
		//move in unit steps so that anything in the body's path is detected. The number of
		//steps is the larger component (both are equal for diagonals); the unit direction keeps
		//the sign, so negative dx/dy work as well.
		int unitdx = Integer.signum(dx),
			unitdy = Integer.signum(dy),
			steps = Math.max(absx, absy);
		return translateBody(b, Direction.get(unitdx, unitdy), steps);		
	}
	
	/**
	 * Moves a body up to {@code speed} cells in direction {@code d}. Each step stops at a border
	 * cell or when {@link #move(Body, Direction, GridCell)} refuses entry; the body ends up in the
	 * last cell it successfully entered, centred in that cell.
	 *
	 * @return true if the body moved at least one cell
	 */
	public boolean translateBody(Body b, Direction d, int speed) {
		GridCell	origCell = grid.getCell(b.getX(), b.getY()),
					currCell = origCell,
					targetCell;
		int i;
		for ( i = 0; i < speed; i++) {
			targetCell = currCell.go(d);
			if ( targetCell.isBorder() || !move(b, d, targetCell) )
				break;
			currCell = targetCell;
		}
		if ( i > 0 ) {
			origCell.removeBody(b);
			currCell.addBody(b);
			b.setLocation(currCell.getCenterX(), currCell.getCenterY());
			return true;
		}
		return false;
	}
	
	/**
	 * Decides whether body {@code b} may enter {@code targetCell}, firing collision events for
	 * the bodies it meets. An obstacle blocks the move; movable bodies are pushed one cell in
	 * direction {@code d} and block the move if they cannot be pushed; all other bodies are
	 * only reported as collisions.
	 *
	 * @return true if {@code b} may enter the target cell
	 */
	protected boolean move( Body b, Direction d, GridCell targetCell ) {
		// Iterate over a snapshot: pushing a movable body removes it from targetCell, which
		// must not happen while iterating over the cell's own occupier list.
		List<Body> occupiers = new ArrayList<Body>(targetCell.getOccupiers());
		for ( Body targetBody : occupiers )
			if ( targetBody.getType() == Body.OBSTACLE ) {
				fireCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
				return false;
			}
		for ( Body targetBody : occupiers )
			if ( targetBody.getType() == Body.MOVABLE ) {
				fireCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
				if ( !translateBody(targetBody, d, 1) )
					return false;
			}
		//just create collisionevent for the other body types
		for ( Body targetBody : occupiers )
			if ( targetBody.getType() != Body.OBSTACLE && targetBody.getType() != Body.MOVABLE ) {
				fireCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
			}
				
		return true;
	}

	@Override
	public Grid getGrid() {
		return grid;
	}
	
	/**
	 * Returns the reward function, or null if none has been set.
	 */
	public RewardFunction getRewardFunction() {
		return reward;
	}

	/**
	 * Sets the reward function and gives it a broker through which it can inspect this world
	 * and end the current episode.
	 */
	public void setRewardFunction(GridRewardFunction reward) {
		this.reward = reward;
		reward.setRewardBroker(new RewardBrokerImpl());
	}
	
	/**
	 * Registers the body with every cell its figure, at its current location, intersects.
	 */
	protected void addToCells(Body b) {
		for ( GridCell c : getIntersectingCellsLarge(b.getFigure()) )
			c.addBody(b);
	}
	
	/**
	 * Unregisters the body from every cell its figure, at its current location, intersects.
	 */
	protected void removeFromCells(Body b) {
		for ( GridCell c : getIntersectingCellsLarge(b.getFigure()) )
			c.removeBody(b);
	}
	
	/**
	 * Fires a collision event between {@code b1} and {@code b2}, located at {@code b2}'s
	 * location. NOTE(review): assumes {@code b2.getLocation()} is a {@link Point2D.Double}.
	 */
	protected void fireCollisionEvent(Body b1, Body b2, Point2D.Double speed) {
		CollisionEvent event = new CollisionEvent(b1, b2, speed, (Point2D.Double)b2.getLocation());
		fireCollisionEvent(event);
	}
	
	/**
	 * Delivers the event to all registered collision listeners, in registration order.
	 */
	protected void fireCollisionEvent(CollisionEvent event) {
		for ( CollisionListener listener : listeners )
			listener.collisionOccurred(event);
	}
	
	/**
	 * Places every body at a random position such that no two bodies share a cell, drawing
	 * top-left cells from a shared list of still-free grid positions.
	 */
	protected void randomlyPositionAll() {
		List<Point> freePositions = new ArrayList<Point>(grid.getWidth()*grid.getHeight());
		for ( int x = 0; x < grid.getWidth(); x++ )
			for ( int y = 0; y < grid.getHeight(); y++ )
				freePositions.add( new Point(x,y) );
		
		//randomly put everything back
		for ( Body b : bodies )
			setRandomPositionFromList(b, freePositions);
	}
	
	/**
	 * Debugging helper: puts an obstacle at a hard-coded position (two cells at world
	 * coordinates that assume 100x100 cells), prints whether {@link #closedLoop(Body, List)}
	 * reports a closed loop, and moves the body there.
	 */
	protected void testClosedLoop(Body b) {
		double cellwidth = grid.getCellWidth(),
				cellheight = grid.getCellHeight();
		Rectangle2D bounds = b.getFigure().getBounds2D();
		int bodywidth = (int)Math.ceil((bounds.getWidth()-1) / cellwidth);
		int x = 0, y = 0;
		GridCell cell;
		List<GridCell> cells = new ArrayList<GridCell>();
		if ( bodywidth > 1 ) {
			x = 0; y = 7;
			System.out.println("==================== Horizontal body");
			cell = grid.getCell(0, 700);
			cells.add(cell);
			cell = grid.getCell(100, 700);
			cells.add(cell);
			for ( GridCell c : cells )
				c.addBody(b);
		} else {
			x = 2; y = 8;
			System.out.println("==================== Vertical body");
			cell = grid.getCell(200, 800);
			cells.add(cell);
			cell = grid.getCell(200, 900);
			cells.add(cell);
			for ( GridCell c : cells )
				c.addBody(b);
		}
		
		System.out.println("Closed loop: " + closedLoop(b, cells));
		
		
		b.setLocation( (x + 0.01)*cellwidth + 0.5*bounds.getWidth(), 
				(y + 0.01)*cellheight + 0.5*bounds.getHeight());
	}
	
	/**
	 * Places a (possibly multi-cell) body at a random position whose top-left cell is drawn from
	 * {@code freePositions}, retrying until all covered cells are empty and, for obstacles, the
	 * placement does not form a closed loop. The covered grid positions are then removed from
	 * {@code freePositions}.
	 * <p>
	 * NOTE(review): there is no retry limit, so this loops forever if no valid placement exists.
	 */
	protected void setRandomPositionFromList(Body b, List<Point> freePositions) {
		
		double cellwidth = grid.getCellWidth(),
				cellheight = grid.getCellHeight();
		Rectangle2D bounds = b.getFigure().getBounds();
		// body size in cells; the -1 keeps a body exactly one cell wide/high at one cell
		int bodywidth = (int)Math.ceil((bounds.getWidth()-1) / cellwidth),
			bodyheight = (int)Math.ceil((bounds.getHeight()-1) / cellheight),
			gridwidth = grid.getWidth(),
			gridheight = grid.getHeight();

		Point pick;
		List<GridCell> cells = new ArrayList<GridCell>();
		List<Point> points = new ArrayList<Point>();
		GridCell cell;
		boolean occupied;
		
		do {
			//undo the cell registrations of the previous, rejected attempt
			for ( GridCell c : cells )
				c.removeBody(b);
			
			cells.clear();
			points.clear();
			occupied = false;
			
			//pick a free top-left cell from which the body fits within the grid
			do {
				pick = freePositions.get( (int)(Math.random()*freePositions.size()) );
			} while (pick.x + bodywidth > gridwidth || pick.y + bodyheight > gridheight);
				
			//register the body with every covered cell, giving up as soon as one is occupied
			for ( int x = 0; x < bodywidth; x++ ) {
				for ( int y = 0; y < bodyheight; y++ ) {
					cell = grid.getCell((pick.x + x)*cellwidth+1, (pick.y + y)*cellheight+1);
					if ( cell.getOccupiers().size() > 0 ) {
						occupied = true;
						break;
					}
					cell.addBody(b);
					cells.add(cell);
					points.add( new Point(pick.x + x, pick.y + y) );
				}
				if ( occupied )
					break;
			}
		} while (occupied || closedLoop(b, cells));
		
		//the 0.01-cell offset keeps the body's bounds off the cell borders (presumably to avoid
		//floating-point misclassification of the body's cells)
		b.setLocation( (pick.x + 0.01)*cellwidth + 0.5*bounds.getWidth(), 
				(pick.y + 0.01)*cellheight + 0.5*bounds.getHeight());
		
		freePositions.removeAll(points);
	}
	
	/**
	 * Moves a body to a random location where none of the cells it covers is occupied and, for
	 * obstacles, no closed loop is formed, then registers it with those cells. At most 100
	 * locations are tried.
	 *
	 * @throws RuntimeException if no valid location was found within 100 tries
	 */
	protected void setRandomPosition(Body b) {
		final int maxTries = 100;
		int width = grid.getWidth(),
			height = grid.getHeight();	
		double cellwidth = grid.getCellWidth(),
				cellheight = grid.getCellHeight();
		Rectangle2D bounds = b.getFigure().getBounds();
		List<GridCell> cells;
		boolean placed;
		int tries = 0;
		do {
			tries++;
			double x = (int)(Math.random()*(width - (bounds.getWidth()-1) / cellwidth))*cellwidth,
					y = (int)(Math.random()*(height - (bounds.getHeight()-1) / cellheight))*cellheight;
			
			//TODO change mechanism: if body bounds lie exactly on cell borders, floating point
			//errors can make the grid think the body is in the wrong cell
			b.setLocation( x + 0.5*cellwidth, y + 0.5*cellheight);
					
			cells = getIntersectingCellsLarge( b.getFigure() );
			boolean occupied = false;
			for ( GridCell c : cells )
				if ( c.getOccupiers().size() > 0 ) {
					occupied = true;
					break;
				}
			placed = !occupied && !closedLoop(b, cells);
		} while ( !placed && tries < maxTries );
		
		if ( !placed )
			throw new RuntimeException("Could not reposition body after " + maxTries + " tries.");
		
		for ( GridCell c : cells )
			c.addBody(b);
	}
	
	/**
	 * Heuristically checks whether placing obstacle {@code b} on {@code cells} would close a loop
	 * of obstacles (e.g. enclose part of the grid). Always false for non-obstacles. For each cell,
	 * every neighbouring obstacle (other than in the previously examined cell) is followed with
	 * {@link #closedLoopSub(List, GridCell, GridCell)}.
	 * NOTE(review): the body's own cells may already be registered by the caller, in which case
	 * its own cells count as neighbouring obstacles.
	 */
	protected boolean closedLoop(Body b, List<GridCell> cells) {
		if ( b.getType() != Body.OBSTACLE )
			return false;
		
		GridCell[] neighbours;
		List<GridCell> checkedCells = new ArrayList<GridCell>(20);
		GridCell prevCell = null;
		for ( GridCell cell : cells ) {
			neighbours = cell.getNeighbours();
			for ( int i = 0; i < neighbours.length; i++ ) {
				for ( Body occupier : neighbours[i].getOccupiers() ) {
					if ( !neighbours[i].equals(prevCell) && occupier.getType() == Body.OBSTACLE ) {
						checkedCells.add(cell);
						if ( closedLoopSub(checkedCells, cell, neighbours[i]) )
							return true;
						//one obstacle per neighbouring cell is enough
						break;
					}
				}
			}
			prevCell = cell;
		}

		return false;
	}
	
	/**
	 * Recursive step of {@link #closedLoop(Body, List)}: follows a chain of obstacle cells from
	 * {@code prevCell} into {@code currCell}, collecting visited cells in {@code checkedCells}.
	 * Returns true when the chain revisits a checked cell, or when it reaches the rim of the grid
	 * after the chain has touched the rim elsewhere (conservatively treated as a closed loop).
	 */
	protected boolean closedLoopSub(List<GridCell> checkedCells, GridCell prevCell, GridCell currCell) {
		//below if block also invalidates rare cases that are not a closed loop. but who cares.
		if ( currCell.isAtRim() ) {
			//find the first (even-numbered) rim side currCell touches
			int dir;
			for ( dir = 0; dir < 8; dir += 2 )
				if ( currCell.isAtRim(dir) )
					break;
			boolean oneNotAtRim = false, oneAtRim = false;
			for ( GridCell cell : checkedCells ) {
				if ( cell.isAtRim() )
					oneAtRim = true;
				if ( !cell.isAtRim(dir) )
					oneNotAtRim = true;
				if ( oneNotAtRim && oneAtRim ) {
					return true;
				}
			}
			checkedCells.add(currCell);
			return false;
		}
		
		GridCell[] neighbours = currCell.getNeighbours();
		for ( int i = 0; i < neighbours.length; i++ )
			//only continue "away" from where we came from
			if ( !neighbours[i].equals(prevCell) && !neighbours[i].isNeighbour(prevCell) ) {
				if ( checkedCells.contains(neighbours[i]) )
					return true;
				for ( Body occupier : neighbours[i].getOccupiers() )
					if ( occupier.getType() == Body.OBSTACLE ) {
						checkedCells.add(currCell); //will lead to doubles
						if ( closedLoopSub(checkedCells, currCell, neighbours[i]) )
							return true;
						break;
					}
			}
		return false;
	}
	
	/**
	 * Restricted view of the world handed to the reward function: access to bodies and grid,
	 * plus the ability to end the current episode.
	 */
	public interface RewardBroker {
		Map<Agent, AgentBody> getAgentBodyMap();
		List<Body> getBodies();
		Grid getGrid();
		void endEpisode();
	}
	
	/**
	 * Broker backed directly by this world's (live, uncopied) collections.
	 */
	protected class RewardBrokerImpl implements RewardBroker {

		@Override
		public Map<Agent, AgentBody> getAgentBodyMap() {
			return agentBodyMap;
		}

		@Override
		public List<Body> getBodies() {
			return bodies;
		}

		@Override
		public Grid getGrid() {
			return grid;
		}
		
		/**
		 * Marks the episode as ended; observations constructed from now on are terminal.
		 */
		@Override
		public void endEpisode() {
			episodeEnd = true;
		}
	}
}
